{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "0cf0ccfa-41da-45c4-a3bb-6358537a85ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import mindspore\n",
    "import mindcv\n",
    "import mindspore.nn as nn\n",
    "import mindspore.dataset.transforms as transforms\n",
    "import mindspore.dataset.vision as vision\n",
    "from d2l import mindspore as d2l"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1a6a0054-9033-427e-b337-31393b926add",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the small demo archive with the d2l data hub (download URL plus a\n",
    "# hash string, presumably the archive's SHA-1 checksum).\n",
    "#@save\n",
    "d2l.DATA_HUB['dog_tiny'] = (d2l.DATA_URL + 'kaggle_dog_tiny.zip',\n",
    "                            '0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d')\n",
    "\n",
    "# Set `demo` to False to use the full Kaggle competition dataset, which is\n",
    "# then read from ../data/dog-breed-identification instead of being downloaded.\n",
    "demo = True\n",
    "data_dir = (d2l.download_extract('dog_tiny') if demo\n",
    "            else os.path.join('..', 'data', 'dog-breed-identification'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "87c1d179-52f1-416c-b65b-ab4c80920445",
   "metadata": {},
   "outputs": [],
   "source": [
    "def reorg_dog_data(data_dir, valid_ratio):\n",
    "    \"\"\"Reorganise the dataset stored under `data_dir`.\n",
    "\n",
    "    Reads `labels.csv` from `data_dir` and hands the labels, together with\n",
    "    `valid_ratio`, to `d2l.reorg_train_valid`; then calls `d2l.reorg_test`.\n",
    "    \"\"\"\n",
    "    labels_csv = os.path.join(data_dir, 'labels.csv')\n",
    "    d2l.reorg_train_valid(data_dir, d2l.read_csv_labels(labels_csv), valid_ratio)\n",
    "    d2l.reorg_test(data_dir)\n",
    "\n",
    "\n",
    "# The demo subset uses a smaller batch size than the full dataset.\n",
    "if demo:\n",
    "    batch_size = 32\n",
    "else:\n",
    "    batch_size = 128\n",
    "valid_ratio = 0.1\n",
    "reorg_dog_data(data_dir, valid_ratio)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6ee5a0da-c96b-4216-8a5e-65d97233485f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training-time augmentation pipeline; HWC2CHW runs last, so every earlier\n",
    "# step operates on height-width-channel images.\n",
    "train_ops = [\n",
    "    # Randomly crop a region covering 0.08-1 of the original area with an\n",
    "    # aspect ratio between 3/4 and 4/3, then rescale it to 224x224.\n",
    "    vision.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0)),\n",
    "    vision.RandomHorizontalFlip(),\n",
    "    # Randomly change brightness, contrast and saturation.\n",
    "    vision.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4),\n",
    "    # Standardise each channel; mean and std are scaled by 255.\n",
    "    vision.Normalize(mean=[255 * m for m in (0.485, 0.456, 0.406)],\n",
    "                     std=[255 * s for s in (0.229, 0.224, 0.225)]),\n",
    "    vision.HWC2CHW(),\n",
    "]\n",
    "transform_train = transforms.Compose(train_ops)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "76357ad6-d5c8-44de-9c6e-d100d5bde462",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation-time pipeline with no random augmentation.\n",
    "test_ops = [\n",
    "    vision.Resize(256),\n",
    "    # Crop a 224x224 patch from the centre of the image.\n",
    "    vision.CenterCrop(224),\n",
    "    # Same per-channel standardisation constants as the training pipeline.\n",
    "    vision.Normalize([255 * m for m in (0.485, 0.456, 0.406)],\n",
    "                     [255 * s for s in (0.229, 0.224, 0.225)]),\n",
    "    vision.HWC2CHW(),\n",
    "]\n",
    "transform_test = transforms.Compose(test_ops)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9e1cbf9f-52ff-433e-ba1f-58022878de20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every split is read from <data_dir>/train_valid_test/<split>.\n",
    "split_root = os.path.join(data_dir, 'train_valid_test')\n",
    "\n",
    "# Training splits: shuffled, decoded, and passed through the augmenting transform.\n",
    "train_ds = mindspore.dataset.ImageFolderDataset(\n",
    "    os.path.join(split_root, 'train'), shuffle=True, decode=True)\n",
    "train_valid_ds = mindspore.dataset.ImageFolderDataset(\n",
    "    os.path.join(split_root, 'train_valid'), shuffle=True, decode=True)\n",
    "train_ds = train_ds.map(transform_train, 'image')\n",
    "train_valid_ds = train_valid_ds.map(transform_train, 'image')\n",
    "\n",
    "# Evaluation splits: not shuffled, and passed through the test-time transform.\n",
    "valid_ds = mindspore.dataset.ImageFolderDataset(\n",
    "    os.path.join(split_root, 'valid'), shuffle=False, decode=True)\n",
    "test_ds = mindspore.dataset.ImageFolderDataset(\n",
    "    os.path.join(split_root, 'test'), shuffle=False, decode=True)\n",
    "valid_ds = valid_ds.map(transform_test, 'image')\n",
    "test_ds = test_ds.map(transform_test, 'image')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cbf6fab9-3b24-43ea-bee9-8deb304267a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training batches drop the final incomplete batch so every optimiser step\n",
    "# sees exactly `batch_size` samples.\n",
    "train_iter, train_valid_iter = [dataset.batch(batch_size=batch_size, drop_remainder=True)\n",
    "                                for dataset in (train_ds, train_valid_ds)]\n",
    "\n",
    "# Evaluation must cover every sample. Dropping the remainder would silently\n",
    "# exclude validation images from the reported loss, and would leave the\n",
    "# iterator empty whenever the validation split is smaller than one batch.\n",
    "valid_iter = valid_ds.batch(batch_size=batch_size, drop_remainder=False)\n",
    "\n",
    "test_iter = test_ds.batch(batch_size=batch_size, drop_remainder=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "807e4813-c846-405d-b787-f38df1cfef21",
   "metadata": {},
   "outputs": [],
   "source": [
    "import mindspore.common.initializer as initializer\n",
    "def get_net(devices):\n",
    "    \"\"\"Build the fine-tuning network: a pretrained ResNet-34 plus a new head.\n",
    "\n",
    "    Args:\n",
    "        devices: Not used by this function (callers in this notebook pass None).\n",
    "\n",
    "    Returns:\n",
    "        An `nn.SequentialCell` whose `feature` attribute is the pretrained\n",
    "        backbone with all parameters frozen, followed by a trainable\n",
    "        Dense(1000, 256) -> ReLU -> Dense(256, 120) head.\n",
    "    \"\"\"\n",
    "    def _reset_uniform(param, bound):\n",
    "        # Overwrite `param` with a Uniform(bound) initialisation of the same\n",
    "        # shape and dtype.\n",
    "        param.set_data(\n",
    "            initializer.initializer(initializer.Uniform(bound), param.shape, param.dtype))\n",
    "\n",
    "    finetune_net = nn.SequentialCell()\n",
    "    # NOTE(review): the backbone is attached as an attribute instead of being\n",
    "    # appended; presumably MindSpore registers it as a sub-cell so that it runs\n",
    "    # before the head appended below. Confirm before refactoring.\n",
    "    finetune_net.feature = mindcv.create_model('resnet34', pretrained=True)\n",
    "\n",
    "    # New output network with 120 output classes.\n",
    "    output_new = nn.SequentialCell([nn.Dense(1000, 256),\n",
    "                                    nn.ReLU(),\n",
    "                                    nn.Dense(256, 120)])\n",
    "    for _, layer in output_new.cells_and_names():\n",
    "        if not isinstance(layer, nn.Dense):\n",
    "            continue\n",
    "        # Initialisation bound is sqrt(1 / fan_in) for both weight and bias.\n",
    "        bound = (1 / layer.in_channels) ** 0.5\n",
    "        _reset_uniform(layer.weight, bound)\n",
    "        if layer.bias is not None:\n",
    "            _reset_uniform(layer.bias, bound)\n",
    "\n",
    "    finetune_net.append(output_new)\n",
    "\n",
    "    # Freeze the backbone parameters so only the new head is trained.\n",
    "    for backbone_param in finetune_net.feature.get_parameters():\n",
    "        backbone_param.requires_grad = False\n",
    "    return finetune_net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "003bc622-826f-4010-95bf-e60986d42da2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the fine-tuning network. get_net does not use its argument, so None\n",
    "# is passed.\n",
    "net = get_net(None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "81db4534-f4f7-48d6-8d2c-c9c6e2e3f024",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-sample cross entropy (no reduction); callers sum or average it themselves.\n",
    "loss = nn.CrossEntropyLoss(reduction='none')\n",
    "\n",
    "def evaluate_loss(data_iter, net):\n",
    "    \"\"\"Return the mean per-sample loss of `net` over `data_iter`.\n",
    "\n",
    "    The network is switched to inference mode while the loss is computed, so\n",
    "    that mode-dependent layers such as BatchNorm do not use batch statistics,\n",
    "    and its previous mode is restored afterwards.\n",
    "\n",
    "    Args:\n",
    "        data_iter: Iterable yielding (features, labels) batches.\n",
    "        net: Network to evaluate.\n",
    "\n",
    "    Returns:\n",
    "        Sum of the per-sample losses divided by the number of labels seen.\n",
    "\n",
    "    Raises:\n",
    "        ZeroDivisionError: If `data_iter` yields no batches.\n",
    "    \"\"\"\n",
    "    # Remember the current mode so evaluation does not change it for the caller.\n",
    "    was_training = getattr(net, 'training', False)\n",
    "    net.set_train(False)\n",
    "    l_sum, n = 0.0, 0\n",
    "    try:\n",
    "        for features, labels in data_iter:\n",
    "            outputs = net(features)\n",
    "            l = loss(outputs, labels)\n",
    "            l_sum += l.sum()\n",
    "            n += labels.numel()\n",
    "    finally:\n",
    "        net.set_train(was_training)\n",
    "    return (l_sum / n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7888dba5-168c-439c-9cdb-b5ad4ea7867f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(net, train_iter, valid_iter, num_epochs, lr, wd, lr_period, lr_decay):\n",
    "    \"\"\"Train the trainable parameters of `net` with SGD and plot the loss curves.\n",
    "\n",
    "    Args:\n",
    "        net: Network to optimise; only parameters with `requires_grad` set are\n",
    "            handed to the optimiser.\n",
    "        train_iter: Batched training dataset yielding (features, labels).\n",
    "        valid_iter: Batched validation dataset, or None to skip validation.\n",
    "        num_epochs: Number of passes over `train_iter`.\n",
    "        lr: Initial learning rate.\n",
    "        wd: Weight decay passed to the optimiser.\n",
    "        lr_period: Number of epochs between learning-rate decays.\n",
    "        lr_decay: Factor the learning rate is multiplied by every `lr_period` epochs.\n",
    "\n",
    "    The network is left in inference mode when training finishes, so it can be\n",
    "    used for prediction directly.\n",
    "    \"\"\"\n",
    "    devices = None\n",
    "    # One learning rate per optimiser step: constant within an epoch and\n",
    "    # multiplied by `lr_decay` every `lr_period` epochs.\n",
    "    lr_list = d2l.tensor([lr*(lr_decay**(i//lr_period)) \n",
    "                          for i in range(num_epochs) \n",
    "                          for j in range(train_iter.get_dataset_size())])\n",
    "    trainer = nn.SGD((param for param in net.get_parameters() if param.requires_grad), \n",
    "                     learning_rate=lr_list, momentum=0.9, weight_decay=wd)\n",
    "\n",
    "    def forward_fn(inputs, targets):\n",
    "        # Returns the unreduced per-sample loss; the logits are an auxiliary output.\n",
    "        logits = net(inputs)\n",
    "        l = loss(logits, targets)\n",
    "        return l, logits\n",
    "    \n",
    "    grad_fn = mindspore.value_and_grad(forward_fn, None, trainer.parameters, has_aux=True)\n",
    "    \n",
    "    def train_step(inputs, targets):\n",
    "        # One optimisation step; returns the summed batch loss and the logits.\n",
    "        (l, logits), grads = grad_fn(inputs, targets)\n",
    "        trainer(grads)\n",
    "        return l.sum(), logits\n",
    "    \n",
    "    num_batches, timer = train_iter.get_dataset_size(), d2l.Timer()\n",
    "    # Plot roughly five points per epoch. Clamp to 1 so that datasets with\n",
    "    # fewer than five batches do not trigger a modulo-by-zero.\n",
    "    plot_every = max(1, num_batches // 5)\n",
    "    legend = ['train loss']\n",
    "    if valid_iter is not None:\n",
    "        legend.append('valid loss')\n",
    "    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],\n",
    "                            legend=legend)\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        net.set_train()\n",
    "        # metric[0]: summed training loss, metric[1]: number of samples seen.\n",
    "        metric = d2l.Accumulator(2)\n",
    "        for i, (features, labels) in enumerate(train_iter):\n",
    "            timer.start()\n",
    "            l, logits = train_step(features, labels)\n",
    "            metric.add(l, labels.shape[0])\n",
    "            timer.stop()\n",
    "            if (i + 1) % plot_every == 0 or i == num_batches - 1:\n",
    "                animator.add(epoch + (i + 1) / num_batches,\n",
    "                             (metric[0] / metric[1], None))\n",
    "        measures = f'train loss {metric[0] / metric[1]:.3f}'\n",
    "        if valid_iter is not None:\n",
    "            # Validate in inference mode; the next epoch re-enables training mode.\n",
    "            net.set_train(False)\n",
    "            valid_loss = evaluate_loss(valid_iter, net)\n",
    "            animator.add(epoch + 1, (None, valid_loss))\n",
    "\n",
    "    # Training is finished: leave the network in inference mode.\n",
    "    net.set_train(False)\n",
    "    if valid_iter is not None:\n",
    "        measures += f', valid loss {float(valid_loss):.3f}'\n",
    "    print(measures + f'\\n{metric[1] * num_epochs / timer.sum():.1f}'\n",
    "          f' examples/sec on {str(devices)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5c29bf7a-c935-43e2-988c-843c45b66331",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss 1.468, valid loss 1.511\n",
      "13.1 examples/sec on None\n"
     ]
    }
   ],
   "source": [
    "# Optimisation hyperparameters: epochs, initial learning rate, weight decay.\n",
    "num_epochs = 10\n",
    "lr = 1e-4\n",
    "wd = 1e-4\n",
    "# The learning rate is multiplied by `lr_decay` every `lr_period` epochs.\n",
    "lr_period = 2\n",
    "lr_decay = 0.9\n",
    "\n",
    "# Train a freshly built network and validate on the held-out split.\n",
    "net = get_net(None)\n",
    "train(net, train_iter, valid_iter, num_epochs, lr, wd, lr_period, lr_decay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "0bb97ded-47b7-4a77-8515-b6fa480f8d06",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train loss 1.520\n",
      "14.3 examples/sec on None\n"
     ]
    }
   ],
   "source": [
    "# Retrain a fresh network on the combined training and validation data.\n",
    "net, preds = get_net(None), []\n",
    "train(net, train_valid_iter, None, num_epochs, lr, wd, lr_period,\n",
    "      lr_decay)\n",
    "\n",
    "# Predict in inference mode so mode-dependent layers (e.g. BatchNorm) do not\n",
    "# use the statistics of the current test batch.\n",
    "net.set_train(False)\n",
    "softmax = nn.Softmax(axis=1)\n",
    "for X, _ in test_iter:\n",
    "    # One row of per-class probabilities for each test image.\n",
    "    preds.extend(softmax(net(X)).asnumpy())\n",
    "\n",
    "# Submission ids are the test image file names without extension.\n",
    "# NOTE(review): this assumes the unshuffled ImageFolderDataset yields images in\n",
    "# sorted folder/file order, as the previous id logic also did. Confirm.\n",
    "test_root = os.path.join(data_dir, 'train_valid_test', 'test')\n",
    "ids = [os.path.splitext(fname)[0]\n",
    "       for folder in sorted(os.listdir(test_root))\n",
    "       for fname in sorted(os.listdir(os.path.join(test_root, folder)))]\n",
    "if len(ids) != len(preds):\n",
    "    raise ValueError(f'found {len(ids)} test files but {len(preds)} predictions')\n",
    "\n",
    "# Order the class names by their integer index so that column k of the\n",
    "# prediction matrix is labelled with the class whose index is k.\n",
    "class_indexing = train_valid_ds.get_class_indexing()\n",
    "class_names = [name for name, _ in sorted(class_indexing.items(),\n",
    "                                          key=lambda item: item[1])]\n",
    "\n",
    "# Kaggle format: an `id` column followed by one probability column per breed.\n",
    "df = pd.DataFrame(preds, columns=class_names)\n",
    "df.insert(0, 'id', ids)\n",
    "df.to_csv('submission.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MindSpore",
   "language": "python",
   "name": "mindspore"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
